import random
from os.path import splitext
from os import listdir
import os.path as osp
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import cv2
import torchvision.transforms.functional as F

class BasicDataset(Dataset):
    """Segmentation dataset pairing image files with mask files by basename.

    Sample ids are the basenames (extension stripped) of the non-hidden
    entries in ``masks_dir``, sorted so that index -> sample is reproducible.
    For each id the image is ``imgs_dir/<id><img_ext>`` and the mask is
    ``masks_dir/<id><mask_ext>``.

    Args:
        imgs_dir: Directory containing the input images.
        masks_dir: Directory containing the masks; its listing defines the ids.
        train_shape: Output size passed to ``cv2.resize`` as ``dsize``, which
            OpenCV interprets as ``(width, height)``.
        img_ext: Extension appended to an id to locate its image.
        mask_ext: Extension appended to an id to locate its mask.
    """

    def __init__(self, imgs_dir, masks_dir, train_shape=(256, 256),
                 img_ext='.jpg', mask_ext='.png'):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.train_shape = train_shape
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        # listdir order is filesystem-dependent; sort so the same index maps to
        # the same sample across machines and runs. Dotfiles are skipped.
        self.ids = sorted(splitext(name)[0] for name in listdir(masks_dir)
                          if not name.startswith('.'))

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def _read(path, grayscale=False):
        """Load ``path`` with ``cv2.imread``, raising instead of returning None.

        Raises:
            FileNotFoundError: ``path`` is not an existing file.
            ValueError: OpenCV could not decode the file.
        """
        # cv2.imread signals failure by returning None rather than raising;
        # without this check the error would surface later as an opaque
        # failure inside cv2.flip / cv2.resize.
        if not osp.isfile(path):
            raise FileNotFoundError(f'No such file: {path}')
        flags = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
        data = cv2.imread(path, flags)
        if data is None:
            raise ValueError(f'OpenCV could not decode image: {path}')
        return data

    def __getitem__(self, i):
        """Return ``(image, target)`` float32 tensors for sample ``i``.

        ``image`` has shape (3, H, W) with channels in OpenCV's BGR order,
        scaled from uint8 to [0, 1]. ``target`` has shape (1, H, W) and is the
        grayscale mask divided by 255 (so 0/255 masks become 0/1).
        H and W come from ``train_shape`` = (W, H).
        """
        idx = self.ids[i]
        img = self._read(osp.join(self.imgs_dir, idx + self.img_ext))
        mask = self._read(osp.join(self.masks_dir, idx + self.mask_ext),
                          grayscale=True)

        # Random flip half the time. cv2 flip codes: 0 = about the x-axis,
        # 1 = about the y-axis, -1 = both. The same code is applied to image
        # and mask so they stay aligned.
        if random.random() < 0.5:
            flip_code = random.choice([-1, 0, 1])
            img = cv2.flip(img, flip_code)
            mask = cv2.flip(mask, flip_code)

        img = cv2.resize(img, self.train_shape)
        # Nearest-neighbour so mask values are not blended into intermediate
        # values at object boundaries.
        mask = cv2.resize(mask, self.train_shape,
                          interpolation=cv2.INTER_NEAREST)

        image = (img / 255.).astype(np.float32).transpose(2, 0, 1)  # HWC -> CHW
        target = (mask / 255.).astype(np.float32)[np.newaxis]       # add channel axis
        return (torch.from_numpy(np.ascontiguousarray(image)),
                torch.from_numpy(target))


if __name__ == '__main__':
    # Smoke test: iterate the training data once and print per-batch shapes.
    dir_img = r'train_data/images'
    dir_mask = r'train_data/masks'
    dataset = BasicDataset(dir_img, dir_mask)
    # Pinned host memory only speeds up host->GPU copies; requesting it with
    # no CUDA device just emits a warning, so enable it only when CUDA exists.
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0,
                              pin_memory=torch.cuda.is_available())
    for image, target in train_loader:
        print(image.shape, target.shape)